import itertools
import json

import torch

from ModularUtils.ControllerConstants import map_dictfill_to_discrete, generate_permutations, get_label_fill, \
    map_fill_to_discrete
from ModularUtils.ControllerModel import get_generated_labels
from ModularUtils.FunctionsDistribution import match_with_true_dist, conditional_prob, calculate_TVD


def compare_interventions(Exp, label_generators, observed_var, intervention, doPrint):
    """Score generated interventional samples against a stored reference distribution.

    Samples ``observed_var`` from ``label_generators`` under ``intervention``,
    maps the samples to discrete values, loads the reference distribution from
    the JSON file ``Exp.Intv_SCMs + str(intervention) + ".txt"`` and compares
    the two with ``match_with_true_dist``.

    Args:
        Exp: Experiment configuration. ``Exp.Synthetic_Sample_Size`` and
            ``Exp.Intv_SCMs`` are read here; the object is also forwarded to
            the helper functions.
        label_generators: Forwarded unchanged to ``get_generated_labels``.
        observed_var: Names of the variables whose distribution is compared.
        intervention: Intervention specification. Its ``str()`` form selects
            the reference file. An empty (or ``None``) value short-circuits.
        doPrint: Forwarded to ``match_with_true_dist``.

    Returns:
        ``0`` when ``intervention`` is empty or ``None``; otherwise the first
        value returned by ``match_with_true_dist`` (used here as the TVD).

    Raises:
        OSError: If the reference file cannot be opened.
        json.JSONDecodeError: If the reference file is not valid JSON.
    """
    if not intervention:
        return 0

    generated_labels_dict = get_generated_labels(
        Exp, label_generators, {}, {}, intervention, observed_var, Exp.Synthetic_Sample_Size
    )
    chosen_generated_data = map_dictfill_to_discrete(Exp, generated_labels_dict, observed_var)

    with open(Exp.Intv_SCMs + str(intervention) + ".txt") as f:
        loaded_dist_dict = json.load(f)

    # The JSON keys are stringified Python objects and are turned back into
    # objects with eval().
    # NOTE(review): eval() executes whatever expression the file contains. If
    # the keys are always plain literals (e.g. tuples of ints),
    # ast.literal_eval would be the safe replacement -- confirm the key format
    # written by the producer of these files before switching.
    true_dist_dict = {eval(key): value for key, value in loaded_dist_dict.items()}

    tvd, _ = match_with_true_dist(
        Exp, observed_var, chosen_generated_data, true_dist_dict, "feature", doPrint
    )
    print("Intervened:",observed_var, intervention, " tvd:",tvd)

    # Each intervention is treated as its own distribution, so the score is
    # not averaged over interventions: the weight is fixed at 1.
    prob = 1
    print("tvd_list & int_prob : ", ((tvd, prob),))
    return tvd * prob










def compare_conditionals_with_truth(Exp, dataset, label_generators, observed_var, conditioning_var, doPrint):
    """Compare conditionals estimated from ``dataset`` with reference conditionals.

    Builds a table of ``conditional_prob`` values for every combination of
    conditioning-variable and observed-variable assignments, fetches the
    reference table from ``get_cond_probs`` and returns their
    ``calculate_TVD`` distance divided by ``Exp.label_dim``.

    NOTE(review): ``get_cond_probs`` is not imported in the visible part of
    this module. This function also calls ``generate_permutations`` and
    ``conditional_prob`` with different argument lists than
    ``compare_conditionals_within`` does, and divides by ``Exp.label_dim``
    where other functions index into it -- confirm which helper signatures
    are current before relying on this function.

    Args:
        Exp: Experiment configuration; ``obs_state`` and ``label_dim`` are read.
        dataset: Data forwarded to ``conditional_prob``.
        label_generators: Not used by this function.
        observed_var: Names of the variables Y.
        conditioning_var: Names of the variables X being conditioned on.
        doPrint: Forwarded to ``calculate_TVD``.

    Returns:
        The normalised distance between the estimated and reference tables.
    """
    y_assignments = generate_permutations(len(observed_var), Exp.obs_state)
    x_assignments = generate_permutations(len(conditioning_var), Exp.obs_state)

    estimated = {}
    for x_values in x_assignments:
        given = dict(zip(conditioning_var, x_values))
        for y_values in y_assignments:
            target = dict(zip(observed_var, y_values))
            # Key is the Y values followed by the X values (X wins on a name clash).
            joint_key = tuple({**target, **given}.values())
            estimated[joint_key] = conditional_prob(Exp, dataset, target, given)

    reference = get_cond_probs(Exp, observed_var, conditioning_var, load_scm=1)

    print("Final dist")
    print(reference)
    print(estimated)

    # Divide by the label dimension because the tables stack P(Y|x=0),
    # P(Y|x=1), ... and each of those conditionals sums to 1 on its own.
    tvd = calculate_TVD(estimated, reference, doPrint) / Exp.label_dim
    print("P(Y|X tvd", tvd)
    return tvd





def compare_conditionals_within(Exp, dataset, feat, observed_var, conditioning_var, doPrint):
    """Tabulate ``conditional_prob`` for every (observed, conditioning) assignment.

    Args:
        Exp: Experiment configuration; ``Exp.label_dim[name]`` is looked up for
            each variable name and passed to ``generate_permutations``.
        dataset: Data forwarded to ``conditional_prob``.
        feat: Forwarded to ``conditional_prob``.
        observed_var: Names of the variables Y.
        conditioning_var: Names of the variables X being conditioned on.
        doPrint: Not used by this function.

    Returns:
        Dict mapping a tuple of values (observed values first, then
        conditioning values) to the value returned by ``conditional_prob``.
    """
    y_assignments = generate_permutations([Exp.label_dim[name] for name in observed_var])
    x_assignments = generate_permutations([Exp.label_dim[name] for name in conditioning_var])

    cond_table = {}
    for x_values in x_assignments:
        given = dict(zip(conditioning_var, x_values))
        for y_values in y_assignments:
            target = dict(zip(observed_var, y_values))
            # Merge so that Y entries come first and X overrides on a name clash.
            merged = dict(target)
            merged.update(given)
            cond_table[tuple(merged.values())] = conditional_prob(Exp, dataset, feat, target, given)

    return cond_table










def compare_interventions_old(Exp,real_dataset, label_generators, observed_var, intervened_var, doPrint):
    """Legacy interventional comparison, averaged uniformly over interventions.

    For every value assignment of ``intervened_var`` produced by
    ``generate_permutations``, samples labels under that intervention, keeps
    the ``observed_var`` columns and compares them with the distribution
    returned by ``get_synthetic_dist``. The per-intervention TVD and KL values
    are combined with equal weight ``1 / number_of_interventions``.

    NOTE(review): ``joint_prob`` and ``get_synthetic_dist`` are not imported
    in the visible part of this module, and ``get_generated_labels``,
    ``map_fill_to_discrete``, ``generate_permutations`` and
    ``match_with_true_dist`` are called with different argument lists than in
    ``compare_interventions`` -- confirm the helper signatures before reviving
    this function.

    Args:
        Exp: Experiment configuration; ``DEVICE``, ``label_dim``, ``obs_state``,
            ``label_names`` and ``Synthetic_Sample_Size`` are read.
        real_dataset: 2-D tensor of real samples, indexed by column.
        label_generators: Forwarded to ``get_generated_labels``.
        observed_var: Names of the variables whose distribution is compared.
        intervened_var: Names of the variables being intervened on.
        doPrint: Forwarded to ``match_with_true_dist``.

    Returns:
        Tuple ``(total_tvd, total_kl)`` of the uniformly weighted sums.
    """
    # Fill-encoding of the first dataset column. Its only consumer was a
    # since-disabled get_mechanism_label call; the computation is kept so the
    # function touches its inputs exactly as before.
    intv_labels_fill = []  # X
    current_intv_label = real_dataset[:, 0].type(torch.LongTensor).view(-1, 1).to(Exp.DEVICE)
    filled_intv_label = get_label_fill(Exp.label_dim)[current_intv_label].to(Exp.DEVICE)
    ret = filled_intv_label.view(-1, Exp.label_dim)
    intv_labels_fill.append(ret)
    intv_labels_fill_wdno = torch.cat(intv_labels_fill, 1).to(Exp.DEVICE)

    intevened = generate_permutations(len(intervened_var), Exp.obs_state)

    # Joint distribution of the intervened variables in the real data.
    # NOTE(review): the result is no longer used for weighting (uniform
    # weights are applied below); the call is kept to preserve behaviour.
    intervened_cols = [Exp.label_names.index(lb) for lb in intervened_var]
    chosen_real_data = real_dataset[:, intervened_cols].detach().cpu().numpy().astype(int)
    interv_variable_queries = {
        name: intevened[:, col] for col, name in enumerate(intervened_var)
    }
    intervention_probability = joint_prob(chosen_real_data, interv_variable_queries)

    total_tvd = 0
    total_kl = 0
    tvd_list = []
    int_prob = []

    for interv in intevened:
        intervened_query = dict(zip(intervened_var, interv))
        generated_labels = get_generated_labels(Exp, label_generators, {}, {}, intervened_query, Exp.Synthetic_Sample_Size)
        generated_labels = map_fill_to_discrete(Exp, generated_labels).detach().cpu().numpy().astype(int)

        observed_cols = [Exp.label_names.index(lb) for lb in observed_var]
        chosen_generated_data = generated_labels[:, observed_cols]

        _, _, _, true_dist_dict = get_synthetic_dist(Exp,observed_var, intervened_query, load_scm=1)

        tvd, kl = match_with_true_dist(Exp, observed_var, chosen_generated_data, true_dist_dict, doPrint)

        tvd_list.append(tvd)
        print("Intervened:",observed_var, intervened_query, " tvd:",tvd, " ,kl:",kl)

        # Uniform weight per intervention (the empirical weighting through
        # intervention_probability was disabled).
        prob = (1.0 / len(intevened))
        int_prob.append(prob)
        total_tvd += tvd * prob
        total_kl += kl * prob

    print("tvd_list & int_prob : ", tuple(zip(tvd_list, int_prob))  )
    return total_tvd, total_kl




def compare_observations(Exp, real_dataset, generated_labels):
    """Compute the TVD between generated and real data for every label subset.

    Iterates over all non-empty combinations of ``Exp.label_names`` up to size
    ``Exp.num_labels``, selects the matching columns from both datasets, turns
    each selection into a distribution with ``get_distributions_from_samples``
    and compares the two with ``calculate_TVD``.

    NOTE(review): ``get_distributions_from_samples`` is not imported in the
    visible part of this module -- confirm where it is expected to come from.

    Args:
        Exp: Experiment configuration; ``num_labels`` and ``label_names`` are read.
        real_dataset: Real samples, indexed as ``[:, columns]``.
        generated_labels: Generated samples, indexed as ``[:, columns]``.

    Returns:
        Dict mapping each tuple of variable names to its TVD value.
    """
    tvd_by_subset = {}

    for subset_size in range(1, Exp.num_labels + 1):
        for var_subset in itertools.combinations(Exp.label_names, subset_size):
            columns = [Exp.label_names.index(name) for name in var_subset]

            gen_dist = get_distributions_from_samples(
                Exp, var_subset, generated_labels[:, columns]
            )
            real_dist = get_distributions_from_samples(
                Exp, var_subset, real_dataset[:, columns]
            )

            tvd_by_subset[var_subset] = calculate_TVD(gen_dist, real_dist, doPrint=False)

    return tvd_by_subset



def get_intervtocf(label_generators):
    """Print a counterfactual estimate built from products of interventional terms.

    For every pair ``(y, w)`` two probabilities are estimated from generated
    samples and multiplied; the same product is formed from the reference
    distributions returned by ``get_synthetic_dist``. The products are then
    combined into a ratio (Y=1 terms over all terms) and printed next to the
    value returned by ``get_cf_dist``. Nothing is returned.

    NOTE(review): ``ControllerConstants``, ``joint_prob``,
    ``get_synthetic_dist``, ``toKey`` and ``get_cf_dist`` are not imported or
    defined in the visible part of this module, and the helper calls here omit
    the ``Exp`` argument used elsewhere in the file -- this function looks
    like it predates the current helper signatures; confirm before use.

    Args:
        label_generators: Forwarded to ``get_generated_labels``.
    """

    probterms={}
    trueprobs={}
    for y in range(ControllerConstants.obs_state):
        for w in range(ControllerConstants.obs_state):

            # 1st term: query X2=1, Y=y under do(X1=1, W=w), estimated from
            # freshly generated samples restricted to the observed columns.
            observed_var = ["X2","Y"]
            intervened_query= {"X1":1, "W":w}
            generated_labels = get_generated_labels(label_generators, {}, {}, intervened_query,ControllerConstants.Synthetic_Sample_Size)
            generated_labels = ControllerConstants.map_fill_to_discrete(generated_labels).detach().cpu().numpy().astype(int)
            indices = [ControllerConstants.label_names.index(lb) for lb in observed_var]
            chosen_generated_data = generated_labels[:, indices]
            observed_variable_queries={"X2":[1],"Y":[y]}
            prob1 = joint_prob(chosen_generated_data, observed_variable_queries)
            print("Gan P:",observed_variable_queries, " |do",intervened_query,"=",prob1)

            # Reference value for the same query; the lookup key is the
            # flattened tuple of queried values, here (1, y).
            _, _, _, true_dist_dict = get_synthetic_dist(observed_var, intervened_query, load_scm=1)
            tpl =tuple(sum(list(observed_variable_queries.values()),[]))
            true_prob1= true_dist_dict[tpl]
            print("True interventions:", tpl, true_prob1)



            # 2nd term: query W=w under do(X1=0, X2=0). It does not depend on
            # y, but is re-sampled on every iteration of the outer loop.
            observed_var = ["W"]
            intervened_query = {"X1": 0, "X2": 0}
            generated_labels = get_generated_labels(label_generators, {}, {}, intervened_query, ControllerConstants.Synthetic_Sample_Size)
            generated_labels = ControllerConstants.map_fill_to_discrete(generated_labels).detach().cpu().numpy().astype(int)
            indices = [ControllerConstants.label_names.index(lb) for lb in observed_var]
            chosen_generated_data = generated_labels[:, indices]
            observed_variable_queries = {"W": [w]}
            prob2 = joint_prob(chosen_generated_data, observed_variable_queries)
            print("Gan P:",observed_variable_queries, " |do",intervened_query,"=",prob2)

            _, _, _, true_dist_dict = get_synthetic_dist(observed_var, intervened_query, load_scm=1)
            tpl = tuple(sum(list(observed_variable_queries.values()), []))
            true_prob2 = true_dist_dict[tpl]
            print("True interventions:", tpl, true_prob2)


            # prob1 / prob2 are mappings; only their first value is used
            # (presumably the single queried combination -- verify against
            # joint_prob's return format).
            probterms[toKey({"W": w,"Y": y})] = list(prob1.values())[0]* list(prob2.values())[0]
            trueprobs[toKey({"W": w,"Y": y})] = true_prob1* true_prob2



    print(probterms)
    print(trueprobs)
    # Ratio of the Y=1 products to all products, summed over W.
    # NOTE(review): W and Y are hard-coded to the values 0 and 1 here, while
    # the loops above range over ControllerConstants.obs_state.
    numerator = probterms[toKey({"W":0,"Y":1})] + probterms[toKey({"W":1, "Y":1})]
    denominator = numerator + probterms[toKey({"W":0,"Y":0 })] + probterms[toKey({"W":1,"Y":0})]
    est_cf= numerator/denominator
    print("est_cf",est_cf)

    # Same ratio from the reference products; est_cf is reused as a scratch
    # name and is printed under the "true_cf" label.
    numerator = trueprobs[toKey({"W": 0, "Y": 1})] + trueprobs[toKey({"W": 1, "Y": 1})]
    denominator = numerator + trueprobs[toKey({"W": 0, "Y": 0})] + trueprobs[toKey({"W": 1, "Y": 0})]
    est_cf = numerator / denominator
    print("true_cf", est_cf)
    # Value obtained directly from get_cf_dist, printed for comparison.
    print("True counterfactual")
    true_cf= get_cf_dist()
    print("true_cf direct",true_cf)













